/*
 * relation_restriction_equivalence.c
 *
 * This file contains helper functions for planning
 * queries with colocated tables and subqueries.
 *
 * Copyright (c) Citus Data, Inc.
 *
 *-------------------------------------------------------------------------
 */
#include "postgres.h"

#include "catalog/pg_type.h"
#include "nodes/makefuncs.h"
#include "nodes/nodeFuncs.h"
#include "nodes/relation.h"
#include "nodes/pg_list.h"
#include "nodes/primnodes.h"
#include "optimizer/planner.h"
#include "optimizer/pathnode.h"
#include "optimizer/paths.h"
#include "optimizer/var.h"
#include "parser/parsetree.h"

#include "pg_version_constants.h"

#include "distributed/colocation_utils.h"
#include "distributed/distributed_planner.h"
#include "distributed/listutils.h"
#include "distributed/metadata_cache.h"
#include "distributed/multi_logical_optimizer.h"
#include "distributed/multi_logical_planner.h"
#include "distributed/multi_router_planner.h"
#include "distributed/pg_dist_partition.h"
#include "distributed/query_utils.h"
#include "distributed/relation_restriction_equivalence.h"
#include "distributed/shard_pruning.h"
#include "distributed/session_ctx.h"

#include "type_cast.h"

/*
 * AttributeEquivalenceClass
 *
 * Whenever we find an equality clause A = B, where both A and B originate from
 * relation attributes (i.e., not random expressions), we create an
 * AttributeEquivalenceClass to record this knowledge. If we later find another
 * equivalence B = C, we create another AttributeEquivalenceClass. Finally, we can
 * apply transitivity rules and generate a new AttributeEquivalenceClass which includes
 * A, B and C.
 *
 * Note that equality among the members is identified by the varattno and rteIdentity.
 */
typedef struct AttributeEquivalenceClass {
    /* id drawn from Session_ctx::PlanCtx().AttributeEquivalenceId */
    uint32 equivalenceId;
    /* list of AttributeEquivalenceClassMember pointers */
    List* equivalentAttributes;

    /*
     * Only used for UNION queries (see SafeToPushdownUnionSubquery()): the
     * target list position of the partition key of the first relation that
     * exposes one; InvalidAttrNumber while unset.
     */
    Index unionQueryPartitionKeyIndex;
} AttributeEquivalenceClass;

/*
 * FindQueryContainingRteIdentityContext - walker state passed to
 * FindQueryContainingRTEIdentityInternal() (its body is not part of this chunk).
 */
typedef struct FindQueryContainingRteIdentityContext {
    /* the rteIdentity that is being searched for */
    int targetRTEIdentity;
    /* NOTE(review): presumably receives the Query containing the target RTE — confirm */
    Query* query;
} FindQueryContainingRteIdentityContext;

/*
 *  AttributeEquivalenceClassMember - one member expression of an
 *  AttributeEquivalenceClass. The important thing to consider is that
 *  the class member contains an "rteIdentity" field. Note that each RTE_RELATION
 *  is assigned a unique rteIdentity in the AssignRTEIdentities() function.
 *
 *  "varno" and "varattno" are taken directly from the Var clause that is being added
 *  to the attribute equivalence. Since we only use this class for relations, the member
 *  also includes the relation id field.
 */
typedef struct AttributeEquivalenceClassMember {
    /* OID of the relation the attribute belongs to */
    Oid relationId;
    /* query-wide unique identity of the RTE_RELATION (see GetRTEIdentity()) */
    int rteIdentity;
    /* range table index of the Var, only meaningful within its own query level */
    Index varno;
    /* attribute number of the Var within the relation */
    AttrNumber varattno;
} AttributeEquivalenceClassMember;

/* forward declarations of the file-local helper functions */
static bool ContextContainsLocalRelation(RelationRestrictionContext* restrictionContext);
static bool ContextContainsAppendRelation(RelationRestrictionContext* restrictionContext);
static int RangeTableOffsetCompat(PlannerInfo* root, AppendRelInfo* appendRelInfo);
static Var* FindUnionAllVar(PlannerInfo* root, List* translatedVars, Oid relationOid,
                            Index relationRteIndex, Index* partitionKeyIndex);
static bool ContainsMultipleDistributedRelations(
    PlannerRestrictionContext* plannerRestrictionContext);
static List* GenerateAttributeEquivalencesForRelationRestrictions(
    RelationRestrictionContext* restrictionContext);
static AttributeEquivalenceClass* AttributeEquivalenceClassForEquivalenceClass(
    EquivalenceClass* plannerEqClass, RelationRestriction* relationRestriction);
static void AddToAttributeEquivalenceClass(
    AttributeEquivalenceClass* attributeEquivalenceClass, PlannerInfo* root,
    Var* varToBeAdded);
static void AddRteSubqueryToAttributeEquivalenceClass(
    AttributeEquivalenceClass* attributeEquivalenceClass, RangeTblEntry* rangeTableEntry,
    PlannerInfo* root, Var* varToBeAdded);
static Query* GetTargetSubquery(PlannerInfo* root, RangeTblEntry* rangeTableEntry,
                                Var* varToBeAdded);
static void AddUnionAllSetOperationsToAttributeEquivalenceClass(
    AttributeEquivalenceClass* attributeEquivalenceClass, PlannerInfo* root,
    Var* varToBeAdded);
static void AddUnionSetOperationsToAttributeEquivalenceClass(
    AttributeEquivalenceClass* attributeEquivalenceClass, PlannerInfo* root,
    SetOperationStmt* setOperation, Var* varToBeAdded);
static void AddRteRelationToAttributeEquivalenceClass(
    AttributeEquivalenceClass* attrEquivalenceClass, RangeTblEntry* rangeTableEntry,
    Var* varToBeAdded);
static Var* GetVarFromAssignedParam(List* outerPlanParamsList, Param* plannerParam,
                                    PlannerInfo** rootContainingVar);
static Var* SearchPlannerParamList(List* plannerParamList, Param* plannerParam);
static List* GenerateAttributeEquivalencesForJoinRestrictions(
    JoinRestrictionContext* joinRestrictionContext);
static bool AttributeClassContainsAttributeClassMember(
    AttributeEquivalenceClassMember* inputMember,
    AttributeEquivalenceClass* attributeEquivalenceClass);
static List* AddAttributeClassToAttributeClassList(
    List* attributeEquivalenceList, AttributeEquivalenceClass* attributeEquivalence);
static bool AttributeEquivalencesAreEqual(
    AttributeEquivalenceClass* firstAttributeEquivalence,
    AttributeEquivalenceClass* secondAttributeEquivalence);
static AttributeEquivalenceClass* GenerateCommonEquivalence(
    List* attributeEquivalenceList,
    RelationRestrictionContext* relationRestrictionContext);
static AttributeEquivalenceClass* GenerateEquivalenceClassForRelationRestriction(
    RelationRestrictionContext* relationRestrictionContext);
static void ListConcatUniqueAttributeClassMemberLists(
    AttributeEquivalenceClass* firstClass, AttributeEquivalenceClass* secondClass);
static Var* PartitionKeyForRTEIdentityInQuery(Query* query, int targetRTEIndex,
                                              Index* partitionKeyIndex);
static bool AllDistributedRelationsInRestrictionContextColocated(
    RelationRestrictionContext* restrictionContext);
static bool IsNotSafeRestrictionToRecursivelyPlan(Node* node);
static bool HasPlaceHolderVar(Node* node);
static JoinRestrictionContext* FilterJoinRestrictionContext(
    JoinRestrictionContext* joinRestrictionContext, Relids queryRteIdentities);
static bool RangeTableArrayContainsAnyRTEIdentities(RangeTblEntry** rangeTableEntries,
                                                    int rangeTableArrayLength,
                                                    Relids queryRteIdentities);
static Relids QueryRteIdentities(Query* queryTree);

static Query* FindQueryContainingRTEIdentity(Query* mainQuery, int rteIndex);
static bool FindQueryContainingRTEIdentityInternal(
    Node* node, FindQueryContainingRteIdentityContext* context);

static int ParentCountPriorToAppendRel(List* appendRelList, AppendRelInfo* appendRelInfo);

/*
 * AllDistributionKeysInQueryAreEqual reports whether the distribution keys of
 * all relations in the query can be shown to be equal, i.e. either
 *    (i)  the query has joins and every relation is joined on its
 *         partition key, or
 *    (ii) the query consists of UNION set operations and every relation
 *         exposes its partition key at the same ordinal target list position.
 *
 * Queries that have CTEs or that reference local tables always yield false.
 */
bool AllDistributionKeysInQueryAreEqual(
    Query* originalQuery, PlannerRestrictionContext* plannerRestrictionContext)
{
    /* equality checks are not supported for CTEs yet */
    if (originalQuery->cteList != NIL) {
        return false;
    }

    /* ...nor for queries that touch local tables */
    if (ContextContainsLocalRelation(
            plannerRestrictionContext->relationRestrictionContext)) {
        return false;
    }

    if (RestrictionEquivalenceForPartitionKeys(plannerRestrictionContext)) {
        return true;
    }

    bool isUnionQuery =
        originalQuery->setOperations != NULL || ContainsUnionSubquery(originalQuery);

    return isUnionQuery &&
           SafeToPushdownUnionSubquery(originalQuery, plannerRestrictionContext);
}

/*
 * ContextContainsLocalRelation returns true when at least one relation in the
 * given RelationRestrictionContext is not a Citus table (its restriction has
 * citusTable unset), and false otherwise.
 */
static bool ContextContainsLocalRelation(RelationRestrictionContext* restrictionContext)
{
    bool foundLocalRelation = false;
    ListCell* restrictionCell = NULL;

    foreach (restrictionCell, restrictionContext->relationRestrictionList) {
        RelationRestriction* restriction =
            static_cast<RelationRestriction*>(lfirst(restrictionCell));

        if (!restriction->citusTable) {
            foundLocalRelation = true;
            break;
        }
    }

    return foundLocalRelation;
}

/*
 * ContextContainsAppendRelation returns true when at least one relation in the
 * given RelationRestrictionContext is an append-distributed Citus table.
 */
static bool ContextContainsAppendRelation(RelationRestrictionContext* restrictionContext)
{
    bool foundAppendRelation = false;
    ListCell* restrictionCell = NULL;

    foreach (restrictionCell, restrictionContext->relationRestrictionList) {
        RelationRestriction* restriction =
            static_cast<RelationRestriction*>(lfirst(restrictionCell));

        if (IsCitusTableType(restriction->relationId, APPEND_DISTRIBUTED)) {
            foundAppendRelation = true;
            break;
        }
    }

    return foundAppendRelation;
}

/*
 * SafeToPushdownUnionSubquery returns true if the partition keys of all
 * distributed relations can be shown to be equal and all distributed relations
 * are co-located. Equality is derived from the UNION itself (relations whose
 * partition keys appear at the same ordinal position of the target list are put
 * into one equivalence class) combined with the equivalences inferred from
 * relation and join restrictions.
 *
 * The input query is expected to be a top level union query as defined by
 * TopLevelUnionQuery(); NOTE(review): this function does not assert that.
 *
 * Lastly, the function fails to produce correct output if the target lists contain
 * multiple partition keys, such as the following:
 *
 *   select count(*) from (
 *       select user_id, user_id from users_table
 *   union
 *       select 2, user_id from users_table) u;
 *
 * For the above query, although the second item in the target list makes this query
 * safe to push down, the function would fail to return true.
 */
bool SafeToPushdownUnionSubquery(Query* originalQuery,
                                 PlannerRestrictionContext* plannerRestrictionContext)
{
    RelationRestrictionContext* restrictionContext =
        plannerRestrictionContext->relationRestrictionContext;
    JoinRestrictionContext* joinRestrictionContext =
        plannerRestrictionContext->joinRestrictionContext;

    AttributeEquivalenceClass* attributeEquivalence =
        static_cast<AttributeEquivalenceClass*>(
            palloc0(sizeof(AttributeEquivalenceClass)));
    ListCell* relationRestrictionCell = NULL;

    attributeEquivalence->equivalenceId = Session_ctx::PlanCtx().AttributeEquivalenceId++;

    /*
     * Ensure that the partition column is in the same place across all
     * leaf queries in the UNION and construct an equivalence class for
     * these columns.
     */
    foreach (relationRestrictionCell, restrictionContext->relationRestrictionList) {
        RelationRestriction* relationRestriction =
            static_cast<RelationRestriction*>(lfirst(relationRestrictionCell));
        Index partitionKeyIndex = InvalidAttrNumber;
        PlannerInfo* relationPlannerRoot = relationRestriction->plannerInfo;

        int targetRTEIndex = GetRTEIdentity(relationRestriction->rte);
        Var* varToBeAdded = PartitionKeyForRTEIdentityInQuery(
            originalQuery, targetRTEIndex, &partitionKeyIndex);

        /*
         * The union does not have this relation's partition key in the target
         * list. The NULL check guards release builds, where the Assert that
         * used to protect the dereference below is compiled out.
         */
        if (partitionKeyIndex == InvalidAttrNumber || varToBeAdded == NULL) {
            continue;
        }

        /*
         * This should never happen but to be on the safe side, we have this.
         * Valid indexes into the planner's simple_rte_array are strictly
         * smaller than simple_rel_array_size, so an index equal to the size
         * is out of bounds as well.
         */
        if (relationPlannerRoot->simple_rel_array_size <= relationRestriction->index) {
            continue;
        }

        /*
         * We update the varno because we use the original parse tree for finding the
         * var. However the rest of the code relies on a query tree that might be
         * different than the original parse tree because of postgres optimizations.
         * That's why we update the varno to reflect the rteIndex in the modified query
         * tree.
         */
        varToBeAdded->varno = relationRestriction->index;

        /*
         * Remember the first relation's partition key position in the target
         * list; relations whose partition key sits elsewhere are left out of
         * the equivalence class.
         */
        if (attributeEquivalence->unionQueryPartitionKeyIndex == InvalidAttrNumber) {
            attributeEquivalence->unionQueryPartitionKeyIndex = partitionKeyIndex;
        } else if (attributeEquivalence->unionQueryPartitionKeyIndex !=
                   partitionKeyIndex) {
            continue;
        }

        AddToAttributeEquivalenceClass(attributeEquivalence, relationPlannerRoot,
                                       varToBeAdded);
    }

    /*
     * For queries of the form:
     * (SELECT ... FROM a JOIN b ...) UNION (SELECT .. FROM c JOIN d ... )
     *
     * we determine whether all relations are joined on the partition column
     * by adding the equivalence classes that can be inferred from joins.
     */
    List* relationRestrictionAttributeEquivalenceList =
        GenerateAttributeEquivalencesForRelationRestrictions(restrictionContext);
    List* joinRestrictionAttributeEquivalenceList =
        GenerateAttributeEquivalencesForJoinRestrictions(joinRestrictionContext);

    List* allAttributeEquivalenceList =
        list_concat(relationRestrictionAttributeEquivalenceList,
                    joinRestrictionAttributeEquivalenceList);

    allAttributeEquivalenceList =
        lappend(allAttributeEquivalenceList, attributeEquivalence);

    if (!EquivalenceListContainsRelationsEquality(allAttributeEquivalenceList,
                                                  restrictionContext)) {
        /* cannot confirm equality for all distribution columns */
        return false;
    }

    if (!AllDistributedRelationsInRestrictionContextColocated(restrictionContext)) {
        /* distribution columns are equal, but tables are not co-located */
        return false;
    }

    return true;
}

/*
 * RangeTableOffsetCompat returns the range table offset (in glob->finalrtable)
 * for the given appendRelInfo.
 *
 * It walks root->simple_rte_array looking for RTEs flagged as inheritance
 * parents (rte->inh). It skips ParentCountPriorToAppendRel() - 1 of them and
 * stops at the next one, which is taken to be the parent of appendRelInfo.
 * The result is the zero-based parent_relid minus that RTE's zero-based local
 * position.
 */
static int RangeTableOffsetCompat(PlannerInfo* root, AppendRelInfo* appendRelInfo)
{
    int parentsToSkip =
        ParentCountPriorToAppendRel(root->append_rel_list, appendRelInfo) - 1;

    int rteIndex = 1;
    while (rteIndex < root->simple_rel_array_size) {
        if (root->simple_rte_array[rteIndex]->inh) {
            /* stop at our parent; earlier parents belong to other append rels */
            if (parentsToSkip <= 0) {
                break;
            }
            parentsToSkip--;
        }
        rteIndex++;
    }

    /*
     * Postgres adds the global rte array size to parent_relid as an offset.
     * Here we do the reverse operation: Commit on postgres side:
     * 6ef77cf46e81f45716ec981cb08781d426181378
     */
    int parentRelIndex = appendRelInfo->parent_relid - 1;
    int localParentIndex = rteIndex - 1;

    return parentRelIndex - localParentIndex;
}

/*
 * FindUnionAllVar searches translatedVars for the Var that belongs to the
 * UNION ALL side whose range table index is relationRteIndex and whose
 * varattno equals the partition key of the relation identified by relationOid.
 *
 * On success the Var is returned and *partitionKeyIndex is set to its 1-based
 * position in translatedVars. Otherwise NULL is returned and *partitionKeyIndex
 * is set to 0. That is always the case for relations that are not hash/range
 * distributed, since only those are considered here.
 *
 * The root parameter is currently unused and is kept for interface stability.
 */
static Var* FindUnionAllVar(PlannerInfo* root, List* translatedVars, Oid relationOid,
                            Index relationRteIndex, Index* partitionKeyIndex)
{
    *partitionKeyIndex = 0;

    if (!IsCitusTableType(relationOid, STRICTLY_PARTITIONED_DISTRIBUTED_TABLE)) {
        /* we only care about hash and range partitioned tables */
        return NULL;
    }

    Var* relationPartitionKey = DistPartitionKeyOrError(relationOid);

    AttrNumber childAttrNumber = 0;
    ListCell* translatedVarCell = NULL;
    foreach (translatedVarCell, translatedVars) {
        Node* targetNode = (Node*)lfirst(translatedVarCell);
        childAttrNumber++;

        /*
         * PostgreSQL documents that translated_vars may hold NULL entries
         * (for dropped parent columns); IsA() would dereference them.
         */
        if (targetNode == NULL || !IsA(targetNode, Var)) {
            continue;
        }

        Var* targetVar = (Var*)targetNode;
        if (targetVar->varno == relationRteIndex &&
            targetVar->varattno == relationPartitionKey->varattno) {
            *partitionKeyIndex = childAttrNumber;

            return targetVar;
        }
    }

    return NULL;
}

/*
 * RestrictionEquivalenceForPartitionKeys tries to deduce whether every
 * RTE_RELATION is joined with at least one other RTE_RELATION on their
 * partition keys. If that holds for each of them, all RTE_RELATIONs are joined
 * on their partition keys and the function returns true; otherwise it returns
 * false. Reference tables are ignored since they have no partition keys.
 *
 * Cheap checks come first:
 *   - queries with local tables return false,
 *   - queries with at most one distributed relation return true,
 *   - queries with append-distributed tables return false, since those are
 *     never considered co-located.
 *
 * The expensive part relies on AttributeEquivalenceClass, an equivalence class
 * identified by a unique id and made of AttributeEquivalenceClassMembers. A
 * member identifies an attribute uniquely within the whole query through its
 * rteIdentity, whereas varno is only meaningful within a single query level.
 * Equalities found among RTE_RELATIONs are recorded as such classes. A common
 * class holding as many partition-key members as possible is then generated
 * and checked.
 *
 * Both relation restrictions and join restrictions are used, so that as much
 * of the information the Postgres planner hands to extensions as possible is
 * exploited. See GenerateAttributeEquivalencesForRelationRestrictions() and
 * GenerateAttributeEquivalencesForJoinRestrictions() for details.
 */
bool RestrictionEquivalenceForPartitionKeys(PlannerRestrictionContext* restrictionContext)
{
    RelationRestrictionContext* relationContext =
        restrictionContext->relationRestrictionContext;

    if (ContextContainsLocalRelation(relationContext)) {
        return false;
    }

    /* there is a single distributed relation, no need to continue */
    if (!ContainsMultipleDistributedRelations(restrictionContext)) {
        return true;
    }

    /* we never consider append-distributed tables co-located */
    if (ContextContainsAppendRelation(relationContext)) {
        return false;
    }

    List* equivalences = GenerateAllAttributeEquivalences(restrictionContext);
    return RestrictionEquivalenceForPartitionKeysViaEquivalences(restrictionContext,
                                                                 equivalences);
}

/*
 * RestrictionEquivalenceForPartitionKeysViaEquivalences applies the same rules
 * as RestrictionEquivalenceForPartitionKeys(). The caller supplies pre-computed
 * attribute equivalences together with the planner restriction context.
 * Returns true when there is at most one distributed relation, or when all
 * relations turn out to be equal on their partition keys.
 */
bool RestrictionEquivalenceForPartitionKeysViaEquivalences(
    PlannerRestrictionContext* plannerRestrictionContext,
    List* allAttributeEquivalenceList)
{
    if (ContainsMultipleDistributedRelations(plannerRestrictionContext)) {
        return EquivalenceListContainsRelationsEquality(
            allAttributeEquivalenceList,
            plannerRestrictionContext->relationRestrictionContext);
    }

    /* a single distributed relation is trivially fine */
    return true;
}

/*
 * ContainsMultipleDistributedRelations returns true if the planner restriction
 * context references more than one distinct (by rteIdentity) distributed relation.
 *
 * When there is at most one such relation, partition column equality does not
 * need to be checked, for example when:
 *   (i)   the query includes only a single co-located relation, or
 *   (ii)  a co-located relation is joined with one or more reference tables,
 *         even if it is not joined on its partition key.
 * The point of the equality check is to make sure no task needs data from
 * another task. In both cases above, the data each task needs is already
 * available to it.
 */
static bool ContainsMultipleDistributedRelations(
    PlannerRestrictionContext* plannerRestrictionContext)
{
    uint32 distinctDistributedRelations = UniqueRelationCount(
        plannerRestrictionContext->relationRestrictionContext, DISTRIBUTED_TABLE);

    return distinctDistributedRelations > 1;
}

/*
 * GenerateAllAttributeEquivalences builds the attribute equivalences derived
 * from the relation restrictions, followed by those derived from the join
 * restrictions, and returns them concatenated in a single list.
 *
 * The session's attribute equivalence id counter is reset on every call.
 */
List* GenerateAllAttributeEquivalences(
    PlannerRestrictionContext* plannerRestrictionContext)
{
    /* restart the id counter per call so that it cannot overflow */
    Session_ctx::PlanCtx().AttributeEquivalenceId = 1;

    List* fromRelationRestrictions = GenerateAttributeEquivalencesForRelationRestrictions(
        plannerRestrictionContext->relationRestrictionContext);
    List* fromJoinRestrictions = GenerateAttributeEquivalencesForJoinRestrictions(
        plannerRestrictionContext->joinRestrictionContext);

    return list_concat(fromRelationRestrictions, fromJoinRestrictions);
}

/*
 * UniqueRelationCount returns how many distinct relations of the given Citus
 * table type appear in the restriction context. Relations are identified by
 * their rteIdentity, so a relation restriction that shows up more than once
 * is counted a single time. Relations unknown to the Citus metadata cache are
 * skipped.
 */
uint32 UniqueRelationCount(RelationRestrictionContext* restrictionContext,
                           CitusTableType tableType)
{
    List* seenRteIdentities = NIL;
    ListCell* restrictionCell = NULL;

    foreach (restrictionCell, restrictionContext->relationRestrictionList) {
        RelationRestriction* restriction = (RelationRestriction*)lfirst(restrictionCell);

        CitusTableCacheEntry* cacheEntry =
            LookupCitusTableCacheEntry(restriction->relationId);

        /*
         * Non-distributed tables are not expected here, but skipping them is
         * harmless; only tables of the requested type are counted.
         */
        if (cacheEntry != NULL && IsCitusTableTypeCacheEntry(cacheEntry, tableType)) {
            seenRteIdentities = list_append_unique_int(
                seenRteIdentities, GetRTEIdentity(restriction->rte));
        }
    }

    return list_length(seenRteIdentities);
}

/*
 * EquivalenceListContainsRelationsEquality takes a list of attribute
 * equivalences and a relation restriction context. It first merges the
 * equivalences into one common equivalence class. It then checks that every
 * relation in the context appears in that class, skipping Citus tables
 * without a distribution key. Returns true if all of them do.
 */
bool EquivalenceListContainsRelationsEquality(
    List* attributeEquivalenceList, RelationRestrictionContext* restrictionContext)
{
    /*
     * Expand the existing equivalence classes into one common class; the goal
     * is to see whether it contains the partition keys of all relations.
     */
    AttributeEquivalenceClass* commonEquivalenceClass =
        GenerateCommonEquivalence(attributeEquivalenceList, restrictionContext);

    /* gather the rteIdentity of every member into a bitmap */
    Relids commonRteIdentities = NULL;
    ListCell* memberCell = NULL;
    foreach (memberCell, commonEquivalenceClass->equivalentAttributes) {
        AttributeEquivalenceClassMember* member =
            (AttributeEquivalenceClassMember*)lfirst(memberCell);

        commonRteIdentities = bms_add_member(commonRteIdentities, member->rteIdentity);
    }

    /* each relation with a distribution key must be part of the common class */
    ListCell* restrictionCell = NULL;
    foreach (restrictionCell, restrictionContext->relationRestrictionList) {
        RelationRestriction* restriction = (RelationRestriction*)lfirst(restrictionCell);
        int rteIdentity = GetRTEIdentity(restriction->rte);
        Oid relationId = restriction->relationId;

        bool citusTableWithoutDistributionKey =
            IsCitusTable(relationId) && !HasDistributionKey(relationId);
        if (citusTableWithoutDistributionKey) {
            continue;
        }

        if (!bms_is_member(rteIdentity, commonRteIdentities)) {
            return false;
        }
    }

    return true;
}

/*
 * GenerateAttributeEquivalencesForRelationRestrictions converts the planner's
 * equivalence classes into AttributeEquivalenceClasses and returns them as a
 * list. A NULL restriction context yields NIL.
 *
 * For each relation restriction, each EquivalenceClass in its plannerInfo's
 * eq_classes becomes one AttributeEquivalenceClass containing the Vars of that
 * class. LATERAL vars are resolved as well (see GetVarFromAssignedParam()).
 * Those are added using the outer plannerInfo that the Var belongs to.
 */
static List* GenerateAttributeEquivalencesForRelationRestrictions(
    RelationRestrictionContext* restrictionContext)
{
    List* attributeEquivalenceList = NIL;

    if (restrictionContext != NULL) {
        ListCell* restrictionCell = NULL;

        foreach (restrictionCell, restrictionContext->relationRestrictionList) {
            RelationRestriction* relationRestriction =
                (RelationRestriction*)lfirst(restrictionCell);
            List* plannerEqClasses = relationRestriction->plannerInfo->eq_classes;
            ListCell* eqClassCell = NULL;

            foreach (eqClassCell, plannerEqClasses) {
                AttributeEquivalenceClass* converted =
                    AttributeEquivalenceClassForEquivalenceClass(
                        (EquivalenceClass*)lfirst(eqClassCell), relationRestriction);

                attributeEquivalenceList = AddAttributeClassToAttributeClassList(
                    attributeEquivalenceList, converted);
            }
        }
    }

    return attributeEquivalenceList;
}

/*
 * AttributeEquivalenceClassForEquivalenceClass converts a single planner
 * EquivalenceClass, belonging to the given relation restriction, into a freshly
 * allocated AttributeEquivalenceClass. It is a helper for
 * GenerateAttributeEquivalencesForRelationRestrictions().
 *
 * Only ec_members that are plain Vars (after stripping implicit coercions) are
 * added. A member that is a Param (as happens with LATERAL joins) is replaced
 * by the Var it was assigned from, when one can be found.
 */
static AttributeEquivalenceClass* AttributeEquivalenceClassForEquivalenceClass(
    EquivalenceClass* plannerEqClass, RelationRestriction* relationRestriction)
{
    PlannerInfo* plannerInfo = relationRestriction->plannerInfo;
    AttributeEquivalenceClass* attributeEquivalence =
        static_cast<AttributeEquivalenceClass*>(
            palloc0(sizeof(AttributeEquivalenceClass)));

    attributeEquivalence->equivalenceId = Session_ctx::PlanCtx().AttributeEquivalenceId++;

    ListCell* memberCell = NULL;
    foreach (memberCell, plannerEqClass->ec_members) {
        EquivalenceMember* member = (EquivalenceMember*)lfirst(memberCell);
        Node* strippedExpr = strip_implicit_coercions((Node*)member->em_expr);

        if (IsA(strippedExpr, Var)) {
            AddToAttributeEquivalenceClass(attributeEquivalence, plannerInfo,
                                           (Var*)strippedExpr);
            continue;
        }

        /* besides plain Vars, only LATERAL parameters are of interest */
        if (!IsA(strippedExpr, Param)) {
            continue;
        }

        PlannerInfo* outerNodeRoot = NULL;
        Var* lateralVar = GetVarFromAssignedParam(relationRestriction->outerPlanParamsList,
                                                  (Param*)strippedExpr, &outerNodeRoot);
        if (lateralVar != NULL) {
            AddToAttributeEquivalenceClass(attributeEquivalence, outerNodeRoot,
                                           lateralVar);
        }
    }

    return attributeEquivalence;
}

/*
 * GetVarFromAssignedParam returns the Var that was assigned to plannerParam,
 * or NULL if the param is not PARAM_EXEC or no matching Var is found. When a
 * Var is found, *rootContainingVar is set to the PlannerInfo of the outer
 * plan-params entry in which it was found; otherwise it is left untouched.
 *
 * Rationale: while iterating over the equivalence classes of RTE_RELATIONs,
 * some equivalence members turn out to be Params standing for LATERAL vars of
 * other query levels. For each RTE_RELATION we keep the plan_params lists of
 * its outer nodes. Per Postgres, plan_params "contains the expressions that
 * this query level needs to make available to a lower query level that is
 * currently being planned". Those lists are searched by paramid. A hit whose
 * item is a Var means standard_planner replaced that Var with the Param in
 * assign_param_for_var() (src/backend/optimizer/plan/subselect.c).
 */
static Var* GetVarFromAssignedParam(List* outerPlanParamsList, Param* plannerParam,
                                    PlannerInfo** rootContainingVar)
{
    Assert(plannerParam != NULL);

    /* only parameters Postgres added for execution can stand for LATERAL vars */
    if (plannerParam->paramkind != PARAM_EXEC) {
        return NULL;
    }

    ListCell* planParamsCell = NULL;
    foreach (planParamsCell, outerPlanParamsList) {
        RootPlanParams* outerPlanParams =
            static_cast<RootPlanParams*>(lfirst(planParamsCell));
        Var* assignedVar =
            SearchPlannerParamList(outerPlanParams->plan_params, plannerParam);

        if (assignedVar != NULL) {
            *rootContainingVar = outerPlanParams->root;
            return assignedVar;
        }
    }

    return NULL;
}

/*
 * SearchPlannerParamList scans plannerParamList for a PlannerParamItem whose
 * paramId equals plannerParam->paramid and whose item is a Var, and returns
 * that Var. Items with a matching id but a non-Var expression are skipped and
 * the scan continues. Returns NULL when no such item exists.
 */
static Var* SearchPlannerParamList(List* plannerParamList, Param* plannerParam)
{
    ListCell* paramItemCell = NULL;

    foreach (paramItemCell, plannerParamList) {
        PlannerParamItem* paramItem = (PlannerParamItem*)lfirst(paramItemCell);

        if (paramItem->paramId != plannerParam->paramid) {
            continue;
        }

        /* TODO: Should we consider PlaceHolderVar? */
        if (IsA(paramItem->item, Var)) {
            return (Var*)paramItem->item;
        }
    }

    return NULL;
}

/*
 * GenerateCommonEquivalence merges a list of AttributeEquivalenceClasses, all of
 * whose members are partition keys, into a single common equivalence class.
 *
 * The common class is seeded with the partition key of the first distributed
 * relation in relationRestrictionContext, so that the input distributed
 * relations are always part of it. The list is then scanned: whenever a class
 * that has not been merged yet shares at least one member with the common
 * class, all of its members are merged in (uniquely) and the scan restarts from
 * the first class, because a class skipped earlier may overlap with the newly
 * merged members. Already-merged classes are tracked by equivalenceId and
 * skipped.
 *
 * When the list is empty or no distributed relation is found, an empty common
 * class (equivalenceId 0, no members) is returned.
 */
static AttributeEquivalenceClass* GenerateCommonEquivalence(
    List* attributeEquivalenceList,
    RelationRestrictionContext* relationRestrictionContext)
{
    const uint32 classCount = list_length(attributeEquivalenceList);

    AttributeEquivalenceClass* commonClass = static_cast<AttributeEquivalenceClass*>(
        palloc0(sizeof(AttributeEquivalenceClass)));
    commonClass->equivalenceId = 0;

    /* seed: the partition key of the first distributed relation */
    AttributeEquivalenceClass* seedClass =
        GenerateEquivalenceClassForRelationRestriction(relationRestrictionContext);

    /* nothing to merge, or nothing to seed with */
    if (classCount == 0 || seedClass == NULL) {
        return commonClass;
    }

    commonClass->equivalentAttributes = seedClass->equivalentAttributes;
    Bitmapset* mergedClassIds = bms_add_member(NULL, seedClass->equivalenceId);

    uint32 position = 0;
    while (position < classCount) {
        AttributeEquivalenceClass* candidateClass =
            static_cast<AttributeEquivalenceClass*>(
                list_nth(attributeEquivalenceList, position));
        bool overlapsCommonClass = false;

        /* classes merged before already contributed all of their members */
        if (!bms_is_member(candidateClass->equivalenceId, mergedClassIds)) {
            ListCell* memberCell = NULL;

            foreach (memberCell, candidateClass->equivalentAttributes) {
                AttributeEquivalenceClassMember* candidateMember =
                    (AttributeEquivalenceClassMember*)lfirst(memberCell);

                if (AttributeClassContainsAttributeClassMember(candidateMember,
                                                               commonClass)) {
                    overlapsCommonClass = true;
                    break;
                }
            }
        }

        if (!overlapsCommonClass) {
            position++;
            continue;
        }

        ListConcatUniqueAttributeClassMemberLists(commonClass, candidateClass);
        mergedClassIds = bms_add_member(mergedClassIds, candidateClass->equivalenceId);

        /*
         * Restart from the beginning: a class we passed over earlier may overlap
         * with the members that were just merged in.
         */
        position = 0;
    }

    return commonClass;
}

/*
 * GenerateEquivalenceClassForRelationRestriction builds a one-member
 * AttributeEquivalenceClass for the first relation in the restriction context
 * that has a distribution (partition) key, as reported by DistPartitionKey().
 *
 * The single member records that relation's id, rte identity, range table
 * index and partition key attribute number. The class's equivalenceId is 0.
 * Returns NULL when no relation in the context has a partition key.
 */
static AttributeEquivalenceClass* GenerateEquivalenceClassForRelationRestriction(
    RelationRestrictionContext* relationRestrictionContext)
{
    ListCell* restrictionCell = NULL;

    foreach (restrictionCell, relationRestrictionContext->relationRestrictionList) {
        RelationRestriction* restriction = (RelationRestriction*)lfirst(restrictionCell);
        Var* partitionKey = DistPartitionKey(restriction->relationId);

        /* relations without a partition key cannot seed the class */
        if (partitionKey == NULL) {
            continue;
        }

        AttributeEquivalenceClass* singleMemberClass =
            static_cast<AttributeEquivalenceClass*>(
                palloc0(sizeof(AttributeEquivalenceClass)));
        AttributeEquivalenceClassMember* member =
            static_cast<AttributeEquivalenceClassMember*>(
                palloc0(sizeof(AttributeEquivalenceClassMember)));

        member->relationId = restriction->relationId;
        member->rteIdentity = GetRTEIdentity(restriction->rte);
        member->varno = restriction->index;
        member->varattno = partitionKey->varattno;

        singleMemberClass->equivalentAttributes = lappend(NIL, member);

        return singleMemberClass;
    }

    return NULL;
}

/*
 * ListConcatUniqueAttributeClassMemberLists appends to firstClass every member
 * of secondClass that firstClass does not already contain. Membership is
 * decided by AttributeClassContainsAttributeClassMember(), i.e. by matching
 * rteIdentity and varattno. The member structs are shared, not copied, and
 * secondClass is left unchanged.
 */
static void ListConcatUniqueAttributeClassMemberLists(
    AttributeEquivalenceClass* firstClass, AttributeEquivalenceClass* secondClass)
{
    /* snapshot the source list before firstClass's list starts growing */
    List* incomingMembers = secondClass->equivalentAttributes;
    ListCell* incomingCell = NULL;

    foreach (incomingCell, incomingMembers) {
        AttributeEquivalenceClassMember* incomingMember =
            (AttributeEquivalenceClassMember*)lfirst(incomingCell);

        if (!AttributeClassContainsAttributeClassMember(incomingMember, firstClass)) {
            firstClass->equivalentAttributes =
                lappend(firstClass->equivalentAttributes, incomingMember);
        }
    }
}

/*
 * GenerateAttributeEquivalencesForJoinRestrictions builds a list of
 * AttributeEquivalenceClasses from the join restrictions in the given context.
 *
 * For every RestrictInfo of every join restriction whose clause has the form
 * (Var1 = Var2), ignoring implicit coercions and with an operator that
 * implements equality, a new class with a fresh equivalenceId is created.
 * Both Vars are resolved through AddToAttributeEquivalenceClass(), and the
 * class is then offered to AddAttributeClassToAttributeClassList(), which may
 * reject it (e.g. fewer than two members, or a duplicate).
 *
 * Returns NIL when joinRestrictionContext is NULL.
 */
static List* GenerateAttributeEquivalencesForJoinRestrictions(
    JoinRestrictionContext* joinRestrictionContext)
{
    if (joinRestrictionContext == NULL) {
        return NIL;
    }

    List* attributeEquivalenceList = NIL;
    ListCell* joinCell = NULL;

    foreach (joinCell, joinRestrictionContext->joinRestrictionList) {
        JoinRestriction* joinRestriction = (JoinRestriction*)lfirst(joinCell);
        ListCell* rinfoCell = NULL;

        foreach (rinfoCell, joinRestriction->joinRestrictInfoList) {
            RestrictInfo* rinfo = (RestrictInfo*)lfirst(rinfoCell);
            Expr* clause = rinfo->clause;

            /* only binary equality operators can equate two columns */
            if (!IsA(clause, OpExpr)) {
                continue;
            }

            OpExpr* equalityExpr = (OpExpr*)clause;
            if (list_length(equalityExpr->args) != 2 ||
                !OperatorImplementsEquality(equalityExpr->opno)) {
                continue;
            }

            /* look through implicit coercions; both sides must be plain columns */
            Node* leftOperand =
                strip_implicit_coercions((Node*)linitial(equalityExpr->args));
            Node* rightOperand =
                strip_implicit_coercions((Node*)lsecond(equalityExpr->args));

            if (!IsA(leftOperand, Var) || !IsA(rightOperand, Var)) {
                continue;
            }

            AttributeEquivalenceClass* joinEquivalence =
                static_cast<AttributeEquivalenceClass*>(
                    palloc0(sizeof(AttributeEquivalenceClass)));
            joinEquivalence->equivalenceId =
                Session_ctx::PlanCtx().AttributeEquivalenceId++;

            AddToAttributeEquivalenceClass(joinEquivalence, joinRestriction->plannerInfo,
                                           (Var*)leftOperand);
            AddToAttributeEquivalenceClass(joinEquivalence, joinRestriction->plannerInfo,
                                           (Var*)rightOperand);

            attributeEquivalenceList = AddAttributeClassToAttributeClassList(
                attributeEquivalenceList, joinEquivalence);
        }
    }

    return attributeEquivalenceList;
}

/*
 * AddToAttributeEquivalenceClass is a key function for building the attribute
 * equivalences. The function gets a plannerInfo, var and attribute equivalence
 * class. It searches for the RTE_RELATION(s) that the input var belongs to and
 * adds the found Var(s) to the input attribute equivalence class.
 *
 * Note that the input var could come from a subquery (i.e., not directly from an
 * RTE_RELATION). That's the reason we recursively call the function until the
 * RTE_RELATION found.
 *
 * The algorithm could be summarized as follows:
 *
 *    - If the RTE that corresponds to a relation
 *        - Generate an AttributeEquivalenceMember and add to the input
 *          AttributeEquivalenceClass
 *    - If the RTE that corresponds to a subquery
 *        - If the RTE that corresponds to a UNION ALL subquery
 *            - Iterate on each of the appendRels (i.e., each of the UNION ALL query)
 *            - Recursively add all children of the set operation's
 *              corresponding target entries
 *        - If the corresponding subquery entry is a UNION set operation
 *             - Recursively add all children of the set operation's
 *               corresponding target entries
 *        - If the corresponding subquery is a regular subquery (i.e., No set operations)
 *             - Recursively try to add the corresponding target entry to the
 *               equivalence class
 *
 * Vars that are whole-row references, system columns, outer-join references
 * (PG16+), or whose varno has no range table entry in root are ignored. Any
 * other RTE kind is ignored as well.
 */
static void AddToAttributeEquivalenceClass(
    AttributeEquivalenceClass* attributeEquivalenceClass, PlannerInfo* root,
    Var* varToBeAdded)
{
    /* punt if it's a whole-row var rather than a plain column reference */
    if (varToBeAdded->varattno == InvalidAttrNumber) {
        return;
    }

    /* we also don't want to process ctid, tableoid etc */
    if (varToBeAdded->varattno < InvalidAttrNumber) {
        return;
    }

    /* outer join checks in PG16 */
    if (IsRelOptOuterJoin(root, varToBeAdded->varno)) {
        return;
    }

    /*
     * IsRelOptOuterJoin() only range-checks varno on PG16+ builds, so guard
     * the simple_rte_array lookup ourselves: prefer skipping the Var over
     * reading past the array or dereferencing an empty slot.
     * NOTE(review): assumes simple_rte_array has simple_rel_array_size slots,
     * as PostgreSQL's setup_simple_rel_arrays() allocates it.
     */
    if (varToBeAdded->varno < 1 ||
        varToBeAdded->varno >= (Index)root->simple_rel_array_size) {
        return;
    }

    RangeTblEntry* rangeTableEntry = root->simple_rte_array[varToBeAdded->varno];
    if (rangeTableEntry == NULL) {
        return;
    }

    if (rangeTableEntry->rtekind == RTE_RELATION) {
        AddRteRelationToAttributeEquivalenceClass(attributeEquivalenceClass,
                                                  rangeTableEntry, varToBeAdded);
    } else if (rangeTableEntry->rtekind == RTE_SUBQUERY) {
        AddRteSubqueryToAttributeEquivalenceClass(attributeEquivalenceClass,
                                                  rangeTableEntry, root, varToBeAdded);
    }
}

/*
 * AddRteSubqueryToAttributeEquivalenceClass handles a Var that references an
 * RTE_SUBQUERY. It looks up the subquery target entry the Var points at and,
 * if that entry is a non-junk Var, continues the search inside the subquery:
 *
 *   - rangeTableEntry->inh (set for inheritance or UNION ALL; only UNION ALL is
 *     of interest here): walk the UNION ALL append rels of root.
 *   - the subquery has setOperations (UNION): walk every leaf of the set
 *     operation tree using the subquery's own planner root.
 *   - otherwise (regular subquery): recurse into the subquery's planner root,
 *     but only for Vars of the subquery's own level (varlevelsup == 0).
 *
 * "UNION" and "UNION ALL" are handled separately because the PostgreSQL
 * planner processes/plans them separately. The overall algorithm is described
 * in AddToAttributeEquivalenceClass().
 */
static void AddRteSubqueryToAttributeEquivalenceClass(
    AttributeEquivalenceClass* attributeEquivalenceClass, RangeTblEntry* rangeTableEntry,
    PlannerInfo* root, Var* varToBeAdded)
{
    RelOptInfo* subqueryRelOptInfo = find_base_rel(root, varToBeAdded->varno);
    Query* targetSubquery = GetTargetSubquery(root, rangeTableEntry, varToBeAdded);

    /*
     * No subquery is available when it was never planned, e.g. because it is a
     * RELOPT_DEADREL whose join was removed by join_is_removable(). Such a
     * subquery does not contribute to the query result, its relations do not
     * appear in the plan, and Citus treats it as safe to push down.
     */
    if (targetSubquery == NULL) {
        return;
    }

    TargetEntry* subqueryTargetEntry =
        get_tle_by_resno(targetSubquery->targetList, varToBeAdded->varattno);

    /* only a real (non-junk) target entry that is itself a Var can be followed */
    if (subqueryTargetEntry == NULL || subqueryTargetEntry->resjunk ||
        !IsA(subqueryTargetEntry->expr, Var)) {
        return;
    }

    Var* subqueryVar = (Var*)subqueryTargetEntry->expr;

    if (rangeTableEntry->inh) {
        AddUnionAllSetOperationsToAttributeEquivalenceClass(attributeEquivalenceClass,
                                                            root, subqueryVar);
        return;
    }

    if (targetSubquery->setOperations != NULL) {
        AddUnionSetOperationsToAttributeEquivalenceClass(
            attributeEquivalenceClass, subqueryRelOptInfo->subroot,
            (SetOperationStmt*)targetSubquery->setOperations, subqueryVar);
        return;
    }

    /* regular subquery: follow only Vars that belong to the subquery itself */
    if (subqueryVar->varlevelsup == 0) {
        AddToAttributeEquivalenceClass(attributeEquivalenceClass,
                                       subqueryRelOptInfo->subroot, subqueryVar);
    }
}

/*
 * GetTargetSubquery returns the Query to look into for a Var that references
 * the given subquery range table entry.
 *
 * For an inh entry (the UNION ALL case), this is the RTE's own subquery. For
 * any other subquery RTE, it is the parse tree of the subquery's planner root
 * (the subroot of the base rel for varToBeAdded->varno), or NULL when the
 * subquery was not planned.
 */
static Query* GetTargetSubquery(PlannerInfo* root, RangeTblEntry* rangeTableEntry,
                                Var* varToBeAdded)
{
    if (rangeTableEntry->inh) {
        return rangeTableEntry->subquery;
    }

    RelOptInfo* subqueryRel = find_base_rel(root, varToBeAdded->varno);
    PlannerInfo* subqueryRoot = subqueryRel->subroot;

    /* an unplanned subquery gives us nothing to inspect, so punt */
    if (subqueryRoot == NULL) {
        return NULL;
    }

    Assert(IsA(subqueryRoot, PlannerInfo));

    Query* plannedSubquery = subqueryRoot->parse;
    Assert(IsA(plannedSubquery, Query));

    return plannedSubquery;
}

/*
 * IsRelOptOuterJoin reports whether range table index varNo refers to an
 * outer join.
 *
 * Only meaningful on PG16+ builds: there, an index beyond simple_rel_array or
 * a slot without a RelOptInfo is taken to be an outer join. On older builds
 * this always returns false.
 */
bool IsRelOptOuterJoin(PlannerInfo* root, int varNo)
{
#if PG_VERSION_NUM >= PG_VERSION_16
    /* the bounds check must short-circuit before the array is indexed */
    if (varNo >= root->simple_rel_array_size || root->simple_rel_array[varNo] == NULL) {
        return true;
    }
#endif
    return false;
}

/*
 * AddUnionAllSetOperationsToAttributeEquivalenceClass iterates on the UNION ALL
 * append rels of the given planner root and adds, for each leaf, the Var that
 * corresponds to varToBeAdded to the given equivalence class.
 *
 * Per UNION ALL append rel child:
 *   - children that are themselves inh RTEs are skipped (see comment below);
 *   - RTE_RELATION children are resolved through FindUnionAllVar(); all such
 *     children must expose the partition key at the same target list position,
 *     which is recorded in unionQueryPartitionKeyIndex, otherwise they are
 *     skipped;
 *   - any other child is handled by pointing the Var at the child's range
 *     table index and recursing through AddToAttributeEquivalenceClass().
 *
 * varToBeAdded itself is not modified. The caller passes a Var taken directly
 * from a subquery's target list, so the per-child varno is written to a
 * private copy instead of the caller's node.
 */
static void AddUnionAllSetOperationsToAttributeEquivalenceClass(
    AttributeEquivalenceClass* attributeEquivalenceClass, PlannerInfo* root,
    Var* varToBeAdded)
{
    List* appendRelList = root->append_rel_list;
    ListCell* appendRelCell = NULL;

    /* lazily created copy of varToBeAdded whose varno is retargeted per child */
    Var* childVar = NULL;

    /* iterate on the queries that are part of UNION ALL subqueries */
    foreach (appendRelCell, appendRelList) {
        AppendRelInfo* appendRelInfo = (AppendRelInfo*)lfirst(appendRelCell);

        /*
         * We're only interested in UNION ALL clauses and parent_reloid is invalid
         * only for UNION ALL (i.e., equals to a legitimate Oid for inheritance)
         */
        if (appendRelInfo->parent_reloid != InvalidOid) {
            continue;
        }
        int rtoffset = RangeTableOffsetCompat(root, appendRelInfo);
        int childRelId = appendRelInfo->child_relid - rtoffset;

        if (root->simple_rel_array_size <= childRelId) {
            /* we prefer to return over an Assert or error to be defensive */
            return;
        }

        RangeTblEntry* rte = root->simple_rte_array[childRelId];
        if (rte->inh) {
            /*
             * This code-path may require improvements. If a leaf of a UNION ALL
             * (e.g., an entry in appendRelList) itself is another UNION ALL
             * (e.g., rte->inh = true), the logic here might get into an infinite
             * recursion.
             *
             * The downside of "continue" here is that certain UNION ALL queries
             * that are safe to pushdown may not be pushed down.
             */
            continue;
        } else if (rte->rtekind == RTE_RELATION) {
            Index partitionKeyIndex = 0;
            List* translatedVars = TranslatedVarsForRteIdentity(GetRTEIdentity(rte));
            Var* varToBeAddedOnUnionAllSubquery = FindUnionAllVar(
                root, translatedVars, rte->relid, childRelId, &partitionKeyIndex);
            if (partitionKeyIndex == 0) {
                /* no partition key on the target list */
                continue;
            }

            if (attributeEquivalenceClass->unionQueryPartitionKeyIndex == 0) {
                /* the first partition key index we found */
                attributeEquivalenceClass->unionQueryPartitionKeyIndex =
                    partitionKeyIndex;
            } else if (attributeEquivalenceClass->unionQueryPartitionKeyIndex !=
                       partitionKeyIndex) {
                /*
                 * Partition keys on the leaves of the UNION ALL queries on
                 * different ordinal positions. We cannot pushdown, so skip.
                 */
                continue;
            }

            if (varToBeAddedOnUnionAllSubquery != NULL) {
                AddToAttributeEquivalenceClass(attributeEquivalenceClass, root,
                                               varToBeAddedOnUnionAllSubquery);
            }
        } else {
            /*
             * Retarget a private copy rather than the caller's Var, which lives
             * in a subquery target list and must not be rewritten in place.
             */
            if (childVar == NULL) {
                childVar = (Var*)copyObject(varToBeAdded);
            }

            /* set the varno accordingly for this specific child */
            childVar->varno = childRelId;

            AddToAttributeEquivalenceClass(attributeEquivalenceClass, root, childVar);
        }
    }
}

/*
 * ParentCountPriorToAppendRel returns how many distinct parent_relid values in
 * appendRelList are less than or equal to targetAppendRelInfo's parent_relid.
 * The target's own parent is included in the count when it appears in the
 * list.
 */
static int ParentCountPriorToAppendRel(List* appendRelList,
                                       AppendRelInfo* targetAppendRelInfo)
{
    const int targetParentIndex = targetAppendRelInfo->parent_relid;
    Bitmapset* distinctParentIndexes = NULL;
    ListCell* appendRelCell = NULL;

    foreach (appendRelCell, appendRelList) {
        AppendRelInfo* candidate = (AppendRelInfo*)lfirst(appendRelCell);
        int candidateParentIndex = candidate->parent_relid;

        if (candidateParentIndex > targetParentIndex) {
            continue;
        }

        /* the bitmapset de-duplicates parents shared by several children */
        distinctParentIndexes = bms_add_member(distinctParentIndexes, candidateParentIndex);
    }

    return bms_num_members(distinctParentIndexes);
}

/*
 * AddUnionSetOperationsToAttributeEquivalenceClass collects the range table
 * indexes of every leaf of the given set operation tree and, for each leaf,
 * adds the Var that corresponds to varToBeAdded (same varattno, varno set to
 * the leaf's range table index) to the given equivalence class.
 *
 * varToBeAdded itself is not modified. It is typically a node of the
 * subquery's target list, so the per-leaf varno is written to a private copy.
 *
 * Although the function silently accepts INTERSECT and EXCEPT set operations, they are
 * rejected later in the planning. We prefer this behavior to provide better error
 * messages.
 */
static void AddUnionSetOperationsToAttributeEquivalenceClass(
    AttributeEquivalenceClass* attributeEquivalenceClass, PlannerInfo* root,
    SetOperationStmt* setOperation, Var* varToBeAdded)
{
    List* rangeTableIndexList = NIL;
    ListCell* rangeTableIndexCell = NULL;

    ExtractRangeTableIndexWalker((Node*)setOperation, &rangeTableIndexList);

    if (rangeTableIndexList == NIL) {
        return;
    }

    /* retarget a copy so the caller's (target list) Var stays intact */
    Var* leafVar = (Var*)copyObject(varToBeAdded);

    foreach (rangeTableIndexCell, rangeTableIndexList) {
        int rangeTableIndex = lfirst_int(rangeTableIndexCell);

        leafVar->varno = rangeTableIndex;
        AddToAttributeEquivalenceClass(attributeEquivalenceClass, root, leafVar);
    }
}

/*
 * AddRteRelationToAttributeEquivalenceClass appends a member for varToBeAdded
 * to attrEquivalenceClass when the Var is the distribution column of a Citus
 * table. The member records varattno, varno, the RTE's rteIdentity and the
 * relation id. Nothing is added for non-Citus (local) tables, for Citus tables
 * without a partition key, or for any other column. The input RTE is asserted
 * to be an RTE_RELATION (rteIdentities only exist for those).
 */
static void AddRteRelationToAttributeEquivalenceClass(
    AttributeEquivalenceClass* attrEquivalenceClass, RangeTblEntry* rangeTableEntry,
    Var* varToBeAdded)
{
    const Oid relationId = rangeTableEntry->relid;

    /* local tables never take part in column equalities */
    if (!IsCitusTable(relationId)) {
        return;
    }

    Var* distributionColumn = DistPartitionKey(relationId);

    Assert(rangeTableEntry->rtekind == RTE_RELATION);

    /*
     * Only equivalences between distribution columns of distributed tables are
     * tracked (tables without a partition key are skipped). This is a known
     * limitation: patterns such as
     *    "(intermediate_res INNER JOIN dist dist1) INNER JOIN dist dist2" or
     *    "(ref INNER JOIN dist dist1) JOIN dist dist2"
     * could be pushed down in a single iteration, but without an explicit join
     * condition between the distributed tables no equivalence between them is
     * deduced. Tracking equivalences across all range table entries would let
     * distributed table equivalences be expanded through reference tables and
     * intermediate results.
     */
    if (distributionColumn == NULL ||
        distributionColumn->varattno != varToBeAdded->varattno) {
        return;
    }

    AttributeEquivalenceClassMember* newMember =
        static_cast<AttributeEquivalenceClassMember*>(
            palloc0(sizeof(AttributeEquivalenceClassMember)));

    newMember->varattno = varToBeAdded->varattno;
    newMember->varno = varToBeAdded->varno;
    newMember->rteIdentity = GetRTEIdentity(rangeTableEntry);
    newMember->relationId = relationId;

    attrEquivalenceClass->equivalentAttributes =
        lappend(attrEquivalenceClass->equivalentAttributes, newMember);
}

/*
 * AttributeClassContainsAttributeClassMember returns true if attributeEquivalenceClass
 * already holds a member equal to inputMember. Two members are equal when both
 * their rteIdentity and varattno match; other fields are not compared.
 */
static bool AttributeClassContainsAttributeClassMember(
    AttributeEquivalenceClassMember* inputMember,
    AttributeEquivalenceClass* attributeEquivalenceClass)
{
    ListCell* memberCell = NULL;

    foreach (memberCell, attributeEquivalenceClass->equivalentAttributes) {
        AttributeEquivalenceClassMember* existingMember =
            (AttributeEquivalenceClassMember*)lfirst(memberCell);
        bool sameRte = existingMember->rteIdentity == inputMember->rteIdentity;
        bool sameColumn = existingMember->varattno == inputMember->varattno;

        if (sameRte && sameColumn) {
            return true;
        }
    }

    return false;
}

/*
 * AddAttributeClassToAttributeClassList returns attributeEquivalenceList with
 * attributeEquivalence appended, unless the class is not worth keeping:
 *
 *   - it is NULL;
 *   - it has fewer than two members (zero or one member can legitimately occur,
 *     see AddToAttributeEquivalenceClass(), but contributes no equivalence);
 *   - an equal class (per AttributeEquivalencesAreEqual()) is already in the
 *     list.
 *
 * In those cases the list is returned unchanged.
 */
static List* AddAttributeClassToAttributeClassList(
    List* attributeEquivalenceList, AttributeEquivalenceClass* attributeEquivalence)
{
    if (attributeEquivalence == NULL ||
        list_length(attributeEquivalence->equivalentAttributes) < 2) {
        return attributeEquivalenceList;
    }

    ListCell* existingCell = NULL;
    foreach (existingCell, attributeEquivalenceList) {
        AttributeEquivalenceClass* existingEquivalence =
            (AttributeEquivalenceClass*)lfirst(existingCell);

        /* duplicates add nothing */
        if (AttributeEquivalencesAreEqual(existingEquivalence, attributeEquivalence)) {
            return attributeEquivalenceList;
        }
    }

    return lappend(attributeEquivalenceList, attributeEquivalence);
}

/*
 *  AttributeEquivalencesAreEqual returns true if both input attribute equivalence
 *  classes contains exactly the same members.
 */
static bool AttributeEquivalencesAreEqual(
    AttributeEquivalenceClass* firstAttributeEquivalence,
    AttributeEquivalenceClass* secondAttributeEquivalence)
{
    List* firstEquivalenceMemberList = firstAttributeEquivalence->equivalentAttributes;
    List* secondEquivalenceMemberList = secondAttributeEquivalence->equivalentAttributes;
    ListCell* firstAttributeEquivalenceCell = NULL;
    ListCell* secondAttributeEquivalenceCell = NULL;

    if (list_length(firstEquivalenceMemberList) !=
        list_length(secondEquivalenceMemberList)) {
        return false;
    }

    foreach (firstAttributeEquivalenceCell, firstEquivalenceMemberList) {
        AttributeEquivalenceClassMember* firstEqMember =
            (AttributeEquivalenceClassMember*)lfirst(firstAttributeEquivalenceCell);
        bool foundAnEquivalentMember = false;

        foreach (secondAttributeEquivalenceCell, secondEquivalenceMemberList) {
            AttributeEquivalenceClassMember* secondEqMember =
                (AttributeEquivalenceClassMember*)lfirst(secondAttributeEquivalenceCell);

            if (firstEqMember->rteIdentity == secondEqMember->rteIdentity &&
                firstEqMember->varattno == secondEqMember->varattno) {
                foundAnEquivalentMember = true;
                break;
            }
        }

        /* we couldn't find an equivalent member */
        if (!foundAnEquivalentMember) {
            return false;
        }
    }

    return true;
}

/*
 * ContainsUnionSubquery returns true if queryTree, possibly through a chain of
 * wrapping queries, contains a subquery whose top-level set operation is a
 * UNION, with no joins above that set operation.
 *
 * Each level, starting at queryTree, must have exactly one range table entry
 * in its jointree (no joins, and a FROM clause is required), and that entry
 * must be an RTE_SUBQUERY. The chain is followed downwards until a subquery
 * with set operations is reached; the result is whether that set operation is
 * SETOP_UNION. This allows unions wrapped into aggregations and/or simple
 * projections over lower-level queries. Any level that violates the shape
 * yields false.
 *
 * Note that the set operation tree is traversed elsewhere to ensure that only
 * UNIONs are supported; only the top-level operator is checked here.
 */
bool ContainsUnionSubquery(Query* queryTree)
{
    Query* currentQuery = queryTree;

    for (;;) {
        List* joinTreeTableIndexList = NIL;

        ExtractRangeTableIndexWalker((Node*)currentQuery->jointree,
                                     &joinTreeTableIndexList);

        /* more than one entry means a join on top; zero means no FROM */
        if (list_length(joinTreeTableIndexList) != 1) {
            return false;
        }

        Index subqueryRteIndex = linitial_int(joinTreeTableIndexList);
        RangeTblEntry* onlyEntry = rt_fetch(subqueryRteIndex, currentQuery->rtable);
        if (onlyEntry->rtekind != RTE_SUBQUERY) {
            return false;
        }

        Query* subquery = onlyEntry->subquery;
        if (subquery->setOperations != NULL) {
            SetOperationStmt* topSetOperation = (SetOperationStmt*)subquery->setOperations;

            return topSetOperation->op == SETOP_UNION;
        }

        /* no set operation at this level yet: descend into the subquery */
        currentQuery = subquery;
    }
}

/*
 * PartitionKeyForRTEIdentityInQuery finds the partition key Var (if any) that
 * is exposed in the target list of the (sub)query containing the RTE whose rte
 * identity is targetRTEIndex, and that originates from that very RTE.
 *
 * On success a copy of the referenced column Var is returned and
 * *partitionKeyIndex is set to the 1-based position of the matching entry in
 * that query's target list (resjunk entries are counted). NULL is returned,
 * and *partitionKeyIndex is left untouched, when the containing query cannot
 * be found or no matching target entry exists.
 */
static Var* PartitionKeyForRTEIdentityInQuery(Query* originalQuery, int targetRTEIndex,
                                              Index* partitionKeyIndex)
{
    Query* originalQueryContainingRTEIdentity =
        FindQueryContainingRTEIdentity(originalQuery, targetRTEIndex);
    if (!originalQueryContainingRTEIdentity) {
        /*
         * We should always find the query but we have this check for sanity.
         * This check makes sure that if there is a bug while finding the query,
         * we don't get a crash etc. and the only downside will be we might be recursively
         * planning a query that could be pushed down.
         */
        return NULL;
    }

    /*
     * This approach fails to detect when
     * the top level query might have the column indexes in different order:
     * explain
     * SELECT count(*) FROM
     * (
     * SELECT user_id,value_2 FROM events_table
     * UNION
     * SELECT value_2, user_id FROM (SELECT user_id, value_2, random() FROM events_table)
     * as foo ) foobar; So we hit https://github.com/citusdata/citus/issues/5093.
     */
    List* relationTargetList = originalQueryContainingRTEIdentity->targetList;

    /* same setting is used for both column lookups below */
    bool skipOuterVars = false;

    ListCell* targetEntryCell = NULL;
    Index partitionKeyTargetAttrIndex = 0;
    foreach (targetEntryCell, relationTargetList) {
        TargetEntry* targetEntry = (TargetEntry*)lfirst(targetEntryCell);
        Expr* targetExpression = targetEntry->expr;

        /* 1-based position of targetEntry in the target list */
        partitionKeyTargetAttrIndex++;

        if (targetEntry->resjunk || !IsA(targetExpression, Var) ||
            !IsPartitionColumn(targetExpression, originalQueryContainingRTEIdentity,
                               skipOuterVars)) {
            continue;
        }

        Var* targetColumn = (Var*)targetExpression;

        /*
         * We find the referenced table column to support distribution
         * columns that are correlated.
         */
        RangeTblEntry* rteContainingPartitionKey = NULL;
        FindReferencedTableColumn(targetExpression, NIL,
                                  originalQueryContainingRTEIdentity, &targetColumn,
                                  &rteContainingPartitionKey, skipOuterVars);

        /*
         * Guard against the lookup leaving the RTE unset, in the same
         * "don't crash, just don't push down" spirit as the check above.
         */
        if (rteContainingPartitionKey != NULL &&
            rteContainingPartitionKey->rtekind == RTE_RELATION &&
            GetRTEIdentity(rteContainingPartitionKey) == targetRTEIndex) {
            *partitionKeyIndex = partitionKeyTargetAttrIndex;
            return (Var*)copyObject(targetColumn);
        }
    }

    return NULL;
}

/*
 * FindQueryContainingRTEIdentity returns the query or subquery (searched from
 * the given query downward) whose range table holds the RTE_RELATION entry
 * with the given rte identity. Returns NULL when no such entry is found.
 */
static Query* FindQueryContainingRTEIdentity(Query* query, int rteIndex)
{
    /*
     * The walker state is only needed for the duration of this call, so it
     * lives on the stack instead of being palloc'd. The returned Query
     * pointer refers into the query tree, not into this context.
     */
    FindQueryContainingRteIdentityContext findRteIdentityContext = {};
    findRteIdentityContext.targetRTEIdentity = rteIndex;
    findRteIdentityContext.query = NULL;

    FindQueryContainingRTEIdentityInternal((Node*)query, &findRteIdentityContext);
    return findRteIdentityContext.query;
}

/*
 * FindQueryContainingRTEIdentityInternal is the tree walker behind
 * FindQueryContainingRTEIdentity. While descending it keeps context->query
 * pointing at the innermost Query being walked. When an RTE_RELATION with
 * context->targetRTEIdentity is met, it returns true and leaves context->query
 * at the Query that owns that RTE. Otherwise context->query is restored on the
 * way back up.
 */
static bool FindQueryContainingRTEIdentityInternal(
    Node* node, FindQueryContainingRteIdentityContext* context)
{
    if (node == NULL) {
        return false;
    }

    if (IsA(node, RangeTblEntry)) {
        RangeTblEntry* candidateRte = (RangeTblEntry*)node;
        return candidateRte->rtekind == RTE_RELATION &&
               GetRTEIdentity(candidateRte) == context->targetRTEIdentity;
    }

    if (!IsA(node, Query)) {
        return expression_tree_walker(
            node, walker_cast0(FindQueryContainingRTEIdentityInternal), context);
    }

    Query* currentQuery = (Query*)node;
    Query* enclosingQuery = context->query;

    context->query = currentQuery;
    bool found = query_tree_walker(currentQuery,
                                   walker_cast0(FindQueryContainingRTEIdentityInternal),
                                   context, QTW_EXAMINE_RTES);
    if (!found) {
        /* not in this subtree: hand the enclosing query back to the caller */
        context->query = enclosingQuery;
    }

    return found;
}

/*
 * AllDistributedRelationsInRestrictionContextColocated gathers the relation
 * id of every relation restriction in the given context and reports whether
 * the distributed relations among them are all co-located, as decided by
 * AllDistributedRelationsInListColocated.
 */
static bool AllDistributedRelationsInRestrictionContextColocated(
    RelationRestrictionContext* restrictionContext)
{
    List* restrictedRelationIds = NIL;
    ListCell* restrictionCell = NULL;

    foreach (restrictionCell, restrictionContext->relationRestrictionList) {
        RelationRestriction* restriction =
            (RelationRestriction*)lfirst(restrictionCell);
        restrictedRelationIds =
            lappend_oid(restrictedRelationIds, restriction->relationId);
    }

    return AllDistributedRelationsInListColocated(restrictedRelationIds);
}

/*
 * AllDistributedRelationsInRTEListColocated collects the relid of every range
 * table entry in the given list and reports whether the distributed relations
 * among them are all co-located, as decided by
 * AllDistributedRelationsInListColocated.
 */
bool AllDistributedRelationsInRTEListColocated(List* rangeTableEntryList)
{
    List* entryRelationIds = NIL;
    ListCell* entryCell = NULL;

    foreach (entryCell, rangeTableEntryList) {
        RangeTblEntry* entry = (RangeTblEntry*)lfirst(entryCell);
        entryRelationIds = lappend_oid(entryRelationIds, entry->relid);
    }

    return AllDistributedRelationsInListColocated(entryRelationIds);
}

/*
 * AllDistributedRelationsInListColocated determines whether all of the
 * distributed relations in the given list of relation ids are co-located.
 *
 * Relations that are not Citus tables, and Citus tables that are not of the
 * DISTRIBUTED_TABLE type, are ignored. Any append-distributed relation in the
 * list makes the function return false, even when it is the only distributed
 * relation present. A list containing no distributed relations yields true.
 */
bool AllDistributedRelationsInListColocated(List* relationList)
{
    int initialColocationId = INVALID_COLOCATION_ID;
    Oid relationId = InvalidOid;

    foreach_declared_oid(relationId, relationList)
    {
        /* only distributed Citus tables take part in the co-location check */
        if (!IsCitusTable(relationId) ||
            !IsCitusTableType(relationId, DISTRIBUTED_TABLE)) {
            continue;
        }

        if (IsCitusTableType(relationId, APPEND_DISTRIBUTED)) {
            /*
             * Append-distributed tables are never considered co-located, so a
             * single one is enough to answer false. Note that this does not
             * depend on how many other distributed relations are in the list.
             */
            return false;
        }

        int colocationId = TableColocationId(relationId);

        /* the first distributed relation fixes the co-location group */
        if (initialColocationId == INVALID_COLOCATION_ID) {
            initialColocationId = colocationId;
        } else if (colocationId != initialColocationId) {
            return false;
        }
    }

    return true;
}

/*
 * DistributedRelationIdList returns the list of unique relation ids of all
 * Citus tables (relations for which IsCitusTable() is true) referenced
 * anywhere in the given query tree. Despite the function name this is not
 * limited to DISTRIBUTED_TABLE-type Citus tables. Non-Citus relations are
 * left out. The ids appear in first-seen order.
 */
List* DistributedRelationIdList(Query* query)
{
    List* rangeTableList = NIL;
    List* relationIdList = NIL;
    ListCell* tableEntryCell = NULL;

    ExtractRangeTableRelationWalker((Node*)query, &rangeTableList);
    List* tableEntryList = TableEntryList(rangeTableList);

    foreach (tableEntryCell, tableEntryList) {
        TableEntry* tableEntry = (TableEntry*)lfirst(tableEntryCell);
        Oid relationId = tableEntry->relationId;

        if (!IsCitusTable(relationId)) {
            continue;
        }

        /* the same relation may be referenced several times; keep it once */
        relationIdList = list_append_unique_oid(relationIdList, relationId);
    }

    return relationIdList;
}

/*
 * FilterPlannerRestrictionForQuery returns a newly allocated
 * PlannerRestrictionContext holding only those relation and join restrictions
 * of plannerRestrictionContext that involve the relations (identified by rte
 * identity) appearing in the given query.
 *
 * In the result, the fast-path restriction context is freshly zeroed and
 * memoryContext is shared with the input. allReferenceTables on the filtered
 * relation restriction context is recomputed from the filtered relations.
 */
PlannerRestrictionContext* FilterPlannerRestrictionForQuery(
    PlannerRestrictionContext* plannerRestrictionContext, Query* query)
{
    Relids queryRteIdentities = QueryRteIdentities(query);

    RelationRestrictionContext* relationRestrictionContext =
        plannerRestrictionContext->relationRestrictionContext;
    JoinRestrictionContext* joinRestrictionContext =
        plannerRestrictionContext->joinRestrictionContext;

    RelationRestrictionContext* filteredRelationRestrictionContext =
        FilterRelationRestrictionContext(relationRestrictionContext, queryRteIdentities);

    JoinRestrictionContext* filteredJoinRestrictionContext =
        FilterJoinRestrictionContext(joinRestrictionContext, queryRteIdentities);

    /* allocate the filtered planner restriction context and set all the fields */
    PlannerRestrictionContext* filteredPlannerRestrictionContext =
        static_cast<PlannerRestrictionContext*>(
            palloc0(sizeof(PlannerRestrictionContext)));
    filteredPlannerRestrictionContext->fastPathRestrictionContext =
        static_cast<FastPathRestrictionContext*>(
            palloc0(sizeof(FastPathRestrictionContext)));

    filteredPlannerRestrictionContext->memoryContext =
        plannerRestrictionContext->memoryContext;

    /* the filtered set is "all reference tables" iff every Citus table is one */
    int totalRelationCount =
        UniqueRelationCount(filteredRelationRestrictionContext, ANY_CITUS_TABLE_TYPE);
    int referenceRelationCount =
        UniqueRelationCount(filteredRelationRestrictionContext, REFERENCE_TABLE);

    filteredRelationRestrictionContext->allReferenceTables =
        (totalRelationCount == referenceRelationCount);

    /* finally set the relation and join restriction contexts */
    filteredPlannerRestrictionContext->relationRestrictionContext =
        filteredRelationRestrictionContext;
    filteredPlannerRestrictionContext->joinRestrictionContext =
        filteredJoinRestrictionContext;

    return filteredPlannerRestrictionContext;
}

/*
 * GetRestrictInfoListForRelation returns copies of the base restriction
 * clauses the planner recorded for the given range table entry, with every
 * Var's varno rewritten to SINGLE_RTE_INDEX so the clauses can be attached to
 * a subquery that has this relation as its only range table entry.
 *
 * The following clauses are left out:
 *   - clauses containing a Param, SubLink, SubPlan or AlternativeSubPlan;
 *   - clauses whose Vars reference anything other than exactly one relation;
 *   - clauses containing a PlaceHolderVar.
 *
 * NIL is returned when no relation restriction exists for the entry. When
 * JoinConditionIsOnFalse() reports the relation's joininfo as false, a list
 * holding a single constant false clause is returned instead.
 */
List* GetRestrictInfoListForRelation(RangeTblEntry* rangeTblEntry,
                                     PlannerRestrictionContext* plannerRestrictionContext)
{
    RelationRestriction* restriction =
        RelationRestrictionForRelation(rangeTblEntry, plannerRestrictionContext);
    if (restriction == NULL) {
        return NIL;
    }

    RelOptInfo* relInfo = restriction->relOptInfo;
    List* baseClauses = relInfo->baserestrictinfo;

    if (JoinConditionIsOnFalse(relInfo->joininfo)) {
        /* WHERE false found: a lone constant false clause says it all */
        return list_make1(makeBoolConst(false, false));
    }

    List* pushableClauses = NIL;
    ListCell* baseClauseCell = NULL;
    foreach (baseClauseCell, baseClauses) {
        RestrictInfo* baseRestrictInfo = (RestrictInfo*)lfirst(baseClauseCell);
        Expr* clause = baseRestrictInfo->clause;

        /* Params and sub-selects are not safe to move into a recursively planned query */
        if (FindNodeMatchingCheckFunction((Node*)clause,
                                          IsNotSafeRestrictionToRecursivelyPlan)) {
            continue;
        }

        /* a clause that spans several relations cannot belong to this one alone */
        Relids clauseRelids = pull_varnos((Node*)clause);
        if (bms_num_members(clauseRelids) != 1) {
            continue;
        }

        /*
         * pull_var_clause_default would error out on a PlaceHolderVar; the
         * planner also restricts the physical Var it wraps, so skipping the
         * clause loses nothing.
         */
        if (FindNodeMatchingCheckFunction((Node*)clause, HasPlaceHolderVar)) {
            continue;
        }

        /* the target subquery has a single RTE, so every Var points at it */
        Expr* clauseCopy = (Expr*)copyObject((Node*)clause);
        List* clauseVars = pull_var_clause_default((Node*)clauseCopy);
        ListCell* clauseVarCell = NULL;
        foreach (clauseVarCell, clauseVars) {
            Var* clauseVar = (Var*)lfirst(clauseVarCell);
            clauseVar->varno = SINGLE_RTE_INDEX;
        }

        pushableClauses = lappend(pushableClauses, clauseCopy);
    }

    return pushableClauses;
}

/*
 * RelationRestrictionForRelation returns the first relation restriction in
 * the planner restriction context whose RTE has the same rte identity as the
 * given range table entry, or NULL when there is none.
 */
RelationRestriction* RelationRestrictionForRelation(
    RangeTblEntry* rangeTableEntry, PlannerRestrictionContext* plannerRestrictionContext)
{
    Relids wantedIdentity = bms_make_singleton(GetRTEIdentity(rangeTableEntry));

    RelationRestrictionContext* matchingContext = FilterRelationRestrictionContext(
        plannerRestrictionContext->relationRestrictionContext, wantedIdentity);
    List* matchingRestrictions = matchingContext->relationRestrictionList;

    if (matchingRestrictions == NIL) {
        return NULL;
    }

    return (RelationRestriction*)linitial(matchingRestrictions);
}

/*
 * IsNotSafeRestrictionToRecursivelyPlan returns true when the given node is a
 * Param, SubLink, SubPlan or AlternativeSubPlan, i.e. a node that makes a
 * restriction unsafe to recursively plan. It inspects only the node itself.
 */
static bool IsNotSafeRestrictionToRecursivelyPlan(Node* node)
{
    switch (nodeTag(node)) {
        case T_Param:
        case T_SubLink:
        case T_SubPlan:
        case T_AlternativeSubPlan:
            return true;
        default:
            return false;
    }
}

/*
 * HasPlaceHolderVar returns true if the given node itself is a PlaceHolderVar.
 * It does not recurse. Pass it to FindNodeMatchingCheckFunction() to search
 * an entire expression tree for a PlaceHolderVar.
 */
static bool HasPlaceHolderVar(Node* node)
{
    return IsA(node, PlaceHolderVar);
}

/*
 * FilterRelationRestrictionContext returns a newly allocated
 * RelationRestrictionContext whose relationRestrictionList keeps, in the
 * original order, only those restrictions of relationRestrictionContext whose
 * RTE's rte identity is a member of queryRteIdentities. No other field of the
 * returned context is set beyond palloc0's zeroing.
 */
RelationRestrictionContext* FilterRelationRestrictionContext(
    RelationRestrictionContext* relationRestrictionContext, Relids queryRteIdentities)
{
    RelationRestrictionContext* keptContext = static_cast<RelationRestrictionContext*>(
        palloc0(sizeof(RelationRestrictionContext)));

    RelationRestriction* candidate = NULL;
    foreach_declared_ptr(candidate, relationRestrictionContext->relationRestrictionList)
    {
        if (!bms_is_member(GetRTEIdentity(candidate->rte), queryRteIdentities)) {
            continue;
        }

        keptContext->relationRestrictionList =
            lappend(keptContext->relationRestrictionList, candidate);
    }

    return keptContext;
}

/*
 * FilterJoinRestrictionContext returns a newly allocated
 * JoinRestrictionContext holding, in the original order, the join
 * restrictions of joinRestrictionContext that involve at least one relation
 * whose rte identity is in queryRteIdentities. A join restriction is kept as
 * soon as any range table entry of its planner info contains such a relation.
 *
 * hasSemiJoin and hasOuterJoin are copied over from the input context.
 */
static JoinRestrictionContext* FilterJoinRestrictionContext(
    JoinRestrictionContext* joinRestrictionContext, Relids queryRteIdentities)
{
    JoinRestrictionContext* keptContext =
        static_cast<JoinRestrictionContext*>(palloc0(sizeof(JoinRestrictionContext)));

    JoinRestriction* candidate = NULL;
    foreach_declared_ptr(candidate, joinRestrictionContext->joinRestrictionList)
    {
        PlannerInfo* candidatePlannerInfo = candidate->plannerInfo;
        bool involvesQueryRelation = RangeTableArrayContainsAnyRTEIdentities(
            candidatePlannerInfo->simple_rte_array,
            candidatePlannerInfo->simple_rel_array_size, queryRteIdentities);

        if (involvesQueryRelation) {
            keptContext->joinRestrictionList =
                lappend(keptContext->joinRestrictionList, candidate);
        }
    }

    /*
     * The join flags are computed per query and we are still looking at the
     * same query, so they carry over unchanged.
     */
    keptContext->hasSemiJoin = joinRestrictionContext->hasSemiJoin;
    keptContext->hasOuterJoin = joinRestrictionContext->hasOuterJoin;

    return keptContext;
}

/*
 * RangeTableArrayContainsAnyRTEIdentities returns true if any RTE_RELATION
 * reachable from rangeTableEntries[1 .. rangeTableArrayLength - 1] has an rte
 * identity in queryRteIdentities. Relation entries are checked directly, and
 * subquery entries are searched for the relations inside them. Entries of any
 * other kind are ignored.
 */
static bool RangeTableArrayContainsAnyRTEIdentities(RangeTblEntry** rangeTableEntries,
                                                    int rangeTableArrayLength,
                                                    Relids queryRteIdentities)
{
    /* slot 0 of simple_rte_array is unused; see the PlannerInfo struct */
    for (int slot = 1; slot < rangeTableArrayLength; slot++) {
        RangeTblEntry* slotEntry = rangeTableEntries[slot];
        Node* searchRoot = NULL;

        if (slotEntry->rtekind == RTE_RELATION) {
            searchRoot = (Node*)slotEntry;
        } else if (slotEntry->rtekind == RTE_SUBQUERY) {
            searchRoot = (Node*)slotEntry->subquery;
        } else {
            /* we currently do not accept any other RTE types here */
            continue;
        }

        List* relationEntries = NIL;
        ExtractRangeTableRelationWalker(searchRoot, &relationEntries);

        RangeTblEntry* relationEntry = NULL;
        foreach_declared_ptr(relationEntry, relationEntries)
        {
            Assert(relationEntry->rtekind == RTE_RELATION);

            if (bms_is_member(GetRTEIdentity(relationEntry), queryRteIdentities)) {
                return true;
            }
        }
    }

    return false;
}

/*
 * QueryRteIdentities returns the set of rte identities of every RTE_RELATION
 * entry found anywhere in the given query tree.
 */
static Relids QueryRteIdentities(Query* queryTree)
{
    List* relationEntries = NIL;

    /* only simple relation entries are collected by this walker */
    ExtractRangeTableRelationWalker((Node*)queryTree, &relationEntries);

    Relids identities = NULL;
    RangeTblEntry* relationEntry = NULL;
    foreach_declared_ptr(relationEntry, relationEntries)
    {
        Assert(relationEntry->rtekind == RTE_RELATION);

        identities = bms_add_member(identities, GetRTEIdentity(relationEntry));
    }

    return identities;
}
